Предсказание музыкального жанра
- 1 Описание
- 2 Загрузка и ознакомление с данными для обучения
- 3 Исследовательский анализ тренировочных данных (EDA)
- 4 Предобработка
- 5 Разработка новых синтетических признаков
- 6 Проверка на мультиколлинеарность
- 7 Отбор финального набора обучающих признаков
- 8 Подготовка данных для работы
- 9 Создание конвеера
- 10 Нахождение лучшей модели
- 11 Оценка модели на валидационной выборке
- 12 Получение предсказайний на тестовых данных
- 13 Анализ важности признаков в предсказаниях модели
- 14 Общий вывод
Описание¶
Постановка задачи¶
Стриминговый сервис "МиФаСоль".
Сервис расширяет работу с новыми артистами и музыкантами, в связи с чем возникла задача правильно классифицировать новые музыкальные треки, чтобы улучшить работу рекомендательной системы.
Был подготовлен датасет, в котором собраны некоторые характеристики музыкальных произведений и их жанры.
Задача: разработать модель, позволяющую классифицировать музыкальные произведения по жанрам.
Описание полей данных¶
instance_id- Уникальный идентификатор трекаtrack_name- Название трекаacousticness- Мера уверенности от 0,0 до 1,0 в том, что трек является акустическим. 1,0 означает высокую степень уверенности в том, что трек является акустическим.danceability- Танцевальность описывает, насколько трек подходит для танцев, основываясь на сочетании музыкальных элементов, включая темп, стабильность ритма, силу ударов и общую регулярность. Значение 0,0 означает наименьшую танцевальность, а 1,0 - наибольшую танцевальность.duration_ms- Продолжительность трека в миллисекундах.energy- Энергия это показатель от 0,0 до 1,0, представляющий собой меру интенсивности и активности. Как правило, энергичные композиции ощущаются как быстрые, громкие и шумные. Например, дэт-метал обладает высокой энергией, в то время как прелюдия Баха имеет низкую оценку этого параметраinstrumentalness- Определяет, содержит ли трек вокал. Звуки "Ooh" и "aah" в данном контексте рассматриваются как инструментальные. Рэп или разговорные треки явно являются "вокальными". Чем ближе значение инструментальности к 1,0, тем больше вероятность того, что трек не содержит вокалаkey- базовый ключ (нота) произведенияliveness- Определяет присутствие аудитории в записи. Более высокие значения liveness означают увеличение вероятности того, что трек был исполнен вживую. Значение выше 0,8 обеспечивает высокую вероятность того, что трек исполняется вживуюloudness- Общая громкость трека в децибелах (дБ)mode- Указывает на модальность (мажорную или минорную) трекаspeechiness- Речевой характер определяет наличие в треке разговорной речи. Чем более исключительно речевой характер носит запись (например, ток-шоу, аудиокнига, поэзия), тем ближе значение атрибута к 1,0. Значения выше 0,66 характеризуют треки, которые, вероятно, полностью состоят из разговорной речи. Значения от 0,33 до 0,66 характеризуют треки, которые могут содержать как музыку, так и речь, как в виде фрагментов, так и в виде слоев, включая такие случаи, как рэп-музыка. Значения ниже 0,33, скорее всего, представляют музыку и другие неречевые треки.tempo- Темп трека в ударах в минуту (BPM). В музыкальной терминологии темп представляет собой скорость или темп данного произведения и напрямую зависит от средней продолжительности тактовobtained_date- дата загрузки в сервисvalence- Показатель от 0,0 до 1,0, характеризующий музыкальный позитив, передаваемый треком. Композиции с высокой валентностью звучат более позитивно (например, радостно, весело, эйфорично), а композиции с низкой валентностью - более негативно (например, грустно, депрессивно, сердито)music_genre- Музыкальный жанр трека (целевой признак).
Основные этапы работ¶
- загрузка и ознакомление с данными,
- предварительная обработка,
- полноценный разведочный анализ,
- разработка новых синтетических признаков,
- проверка на мультиколлинеарность,
- отбор финального набора обучающих признаков,
- выбор и обучение моделей,
- итоговая оценка качества предсказания лучшей модели,
- анализ важности ее признаков.
Требуемые пакеты¶
# installs
!pip install -q phik
!pip install -q scikit-learn==1.3
!pip install -q pandas-profiling[notebook]
!pip install -q -U imbalanced-learn
!pip install -q feature_engine
!pip install -q catboost
!pip install -q shap
WARNING: visions 0.7.5 does not provide the extra 'type-image-path'
Используемые компоненты¶
# import libraries
import os
import re
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams, rcParamsDefault
import seaborn as sns
import numpy as np
import phik
from ydata_profiling import ProfileReport
from imblearn.pipeline import Pipeline
from imblearn.pipeline import make_pipeline as make_imblearn_pipeline
from imblearn.base import FunctionSampler
from feature_engine.selection import DropCorrelatedFeatures
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import StandardScaler
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, IterativeImputer
from sklearn.compose import ColumnTransformer, make_column_selector, make_column_transformer
from sklearn.ensemble import IsolationForest
from sklearn.ensemble import RandomForestClassifier
from catboost import CatBoostClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report
import shap
Основные настройки¶
rcParams['figure.figsize'] = 8,7
%config InlineBackend.figure_format = 'svg'
%matplotlib inline
factor = 1
default_dpi = rcParamsDefault['figure.dpi']
rcParams['figure.dpi'] = default_dpi*factor
pd.options.display.max_colwidth = 50
pd.options.display.max_columns = 20
pd.options.display.max_columns = 50
shap.initjs()
Основные константы проекта¶
PATH_DATASET_TRAIN = 'kaggle_music_genre_train.csv' # информация (~20000) музыкальных треках, которые будут использоваться в качестве обучающих данных.
PATH_DATASET_TEST = 'kaggle_music_genre_test.csv' # информация (~5000) музыкальных треках, которые будут использоваться в качестве тестовых данных
RANDOM_STATE = 69
IS_BEST_MODEL_NOT_FOUND = False # Если лучшая модель не найдена, или хочется повторно найти ее тогда - True
Загрузка и ознакомление с данными для обучения¶
def load_dataset(path):
if os.path.exists(path):
df = pd.read_csv(path)
print(f'Файл "{path}" успешно загружен.')
display(df.describe())
display(df)
df.info()
return df
else:
raise SystemExit(f'Ошибка, файл "{path}" не найден!')
train_df = load_dataset(PATH_DATASET_TRAIN)
Файл "kaggle_music_genre_train.csv" успешно загружен.
| instance_id | acousticness | danceability | duration_ms | energy | instrumentalness | liveness | loudness | speechiness | tempo | valence | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 20394.000000 | 20394.000000 | 20394.000000 | 2.039400e+04 | 20394.000000 | 20394.000000 | 20394.000000 | 20394.000000 | 20394.000000 | 19952.000000 | 20394.000000 |
| mean | 55973.846916 | 0.274783 | 0.561983 | 2.203754e+05 | 0.625276 | 0.159989 | 0.198540 | -8.552998 | 0.091352 | 120.942522 | 0.464588 |
| std | 20695.792545 | 0.321643 | 0.171898 | 1.267283e+05 | 0.251238 | 0.306503 | 0.166742 | 5.499917 | 0.097735 | 30.427590 | 0.243387 |
| min | 20011.000000 | 0.000000 | 0.060000 | -1.000000e+00 | 0.001010 | 0.000000 | 0.013600 | -44.406000 | 0.022300 | 34.765000 | 0.000000 |
| 25% | 38157.250000 | 0.015200 | 0.451000 | 1.775170e+05 | 0.470000 | 0.000000 | 0.097300 | -10.255750 | 0.035600 | 95.921750 | 0.272000 |
| 50% | 56030.000000 | 0.120000 | 0.570000 | 2.195330e+05 | 0.666000 | 0.000144 | 0.130000 | -7.052000 | 0.049050 | 120.012500 | 0.457000 |
| 75% | 73912.750000 | 0.470000 | 0.683000 | 2.660000e+05 | 0.830000 | 0.084475 | 0.253000 | -5.054000 | 0.095575 | 141.966250 | 0.653000 |
| max | 91758.000000 | 0.996000 | 0.978000 | 4.497994e+06 | 0.999000 | 0.996000 | 1.000000 | 3.744000 | 0.942000 | 220.041000 | 0.992000 |
| instance_id | track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 25143.0 | Highwayman | 0.48000 | 0.670 | 182653.0 | 0.351 | 0.017600 | D | 0.1150 | -16.842 | Major | 0.0463 | 101.384 | 4-Apr | 0.450 | Country |
| 1 | 26091.0 | Toes Across The Floor | 0.24300 | 0.452 | 187133.0 | 0.670 | 0.000051 | A | 0.1080 | -8.392 | Minor | 0.0352 | 113.071 | 4-Apr | 0.539 | Rock |
| 2 | 87888.0 | First Person on Earth | 0.22800 | 0.454 | 173448.0 | 0.804 | 0.000000 | E | 0.1810 | -5.225 | Minor | 0.3710 | 80.980 | 4-Apr | 0.344 | Alternative |
| 3 | 77021.0 | No Te Veo - Digital Single | 0.05580 | 0.847 | 255987.0 | 0.873 | 0.000003 | G# | 0.3250 | -4.805 | Minor | 0.0804 | 116.007 | 4-Apr | 0.966 | Hip-Hop |
| 4 | 20852.0 | Chasing Shadows | 0.22700 | 0.742 | 195333.0 | 0.575 | 0.000002 | C | 0.1760 | -5.550 | Major | 0.0487 | 76.494 | 4-Apr | 0.583 | Alternative |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20389 | 47396.0 | O Pato | 0.71900 | 0.725 | -1.0 | 0.483 | 0.000000 | NaN | 0.0797 | -13.314 | Minor | 0.0438 | 87.413 | 4-Apr | 0.942 | Jazz |
| 20390 | 44799.0 | Mt. Washington | 0.19000 | 0.482 | 198933.0 | 0.362 | 0.005620 | F# | 0.0913 | -10.358 | Minor | 0.0299 | 76.879 | 4-Apr | 0.174 | Rock |
| 20391 | 33350.0 | Original Prankster | 0.00061 | 0.663 | 220947.0 | 0.886 | 0.000025 | D | 0.2840 | -4.149 | Major | 0.0358 | 146.803 | 4-Apr | 0.942 | Alternative |
| 20392 | 77920.0 | 4Peat | 0.00310 | 0.914 | 162214.0 | 0.515 | 0.000000 | C# | 0.1050 | -9.934 | Major | 0.3560 | 150.016 | 4-Apr | 0.215 | Rap |
| 20393 | 86375.0 | Trouble (feat. MC Spyder) | 0.05350 | 0.717 | 271885.0 | 0.983 | 0.491000 | D# | 0.1080 | -1.615 | Minor | 0.0942 | 128.004 | 4-Apr | 0.077 | Electronic |
20394 rows × 16 columns
<class 'pandas.core.frame.DataFrame'> RangeIndex: 20394 entries, 0 to 20393 Data columns (total 16 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 instance_id 20394 non-null float64 1 track_name 20394 non-null object 2 acousticness 20394 non-null float64 3 danceability 20394 non-null float64 4 duration_ms 20394 non-null float64 5 energy 20394 non-null float64 6 instrumentalness 20394 non-null float64 7 key 19659 non-null object 8 liveness 20394 non-null float64 9 loudness 20394 non-null float64 10 mode 19888 non-null object 11 speechiness 20394 non-null float64 12 tempo 19952 non-null float64 13 obtained_date 20394 non-null object 14 valence 20394 non-null float64 15 music_genre 20394 non-null object dtypes: float64(11), object(5) memory usage: 2.5+ MB
Данные загружены успешно, первое что бросается в глаза это то что колонка instance_id, которую сделаем в процессе работ индексной, имеет не логичный тип float64 (а по хорошему должен быть int64), а также в колонках не одинаковое значения записей, что скорее всего говорит нам о пропусках.
Исследовательский анализ тренировочных данных (EDA)¶
Поверхностная оценка¶
train_df.isna().sum(axis=0)
instance_id 0 track_name 0 acousticness 0 danceability 0 duration_ms 0 energy 0 instrumentalness 0 key 735 liveness 0 loudness 0 mode 506 speechiness 0 tempo 442 obtained_date 0 valence 0 music_genre 0 dtype: int64
# теперь тоже самое, но в процентах, и только проблемные колонки
def show_miss(data):
'''Отобразить только колонки с пропусками и количество пропусков.'''
miss = ((data.isna().sum() / len(data)).round(3) * 100)
miss = miss[miss > 0.0]
if len(miss) == 0:
print("Пропуски отсутствуют.")
else:
print("Оставшиеся пропуски (%):")
display(miss)
show_miss(train_df)
Оставшиеся пропуски (%):
key 3.6 mode 2.5 tempo 2.2 dtype: float64
Профайлинг (углубленный EDA анализ)¶
Отчет¶
ProfileReport(train_df)
Summarize dataset: 0%| | 0/5 [00:00<?, ?it/s]
Generate report structure: 0%| | 0/1 [00:00<?, ?it/s]
Render HTML: 0%| | 0/1 [00:00<?, ?it/s]
Выводы по отчету¶
Выявленные существенные проблемы:
- Сильная корреляция
acousticness,energy,loudness
- Пропуски
key(735) - 3.6% - (основная нота) произведения - категорияmode(506) - 2.5% - модальность (мажорную или минорную) - категорияtempo(442) - 2.2% - темп трека в ударах в минуту (BPM) - число
Пропусков не так и много, но неприятно то что между ними нет особо сильных пересечений, судя по отчету (Missing values -> Matrix, Heatmap), а значит если мы удалим пропуски эти проценты все суммируются и будет уже ~8.3% от всех данных (а так, есть очень мало значений пересечений пропусков в колонках {key, tempo}, {mode, tempo}, {mode, key}).
- Дизбаланс в признаках
obtained_date- дата загрузки - категория (похоже что этот признак вообще бесполезен для обучения, и нелогичен в использовании)
- Имеются дубликаты
- Имеются выбросы
- В
duration_msимеются отрицательные значения, по факту это выброс, но он не случаен и имеет константное представление выраженное одним числом, можно предположить что это пропуски так были заполнены в данных.
Предобработка¶
Установка индексной колонки¶
print("Количество дубликатов в колонке instance_id:",
train_df['instance_id'].duplicated(keep=False).sum()) # проверка что индексы будут уникальными
Количество дубликатов в колонке instance_id: 0
# изменим тип колонки, с float64 на int64
train_df.instance_id = train_df.instance_id.astype(np.int64)
train_df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 20394 entries, 0 to 20393 Data columns (total 16 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 instance_id 20394 non-null int64 1 track_name 20394 non-null object 2 acousticness 20394 non-null float64 3 danceability 20394 non-null float64 4 duration_ms 20394 non-null float64 5 energy 20394 non-null float64 6 instrumentalness 20394 non-null float64 7 key 19659 non-null object 8 liveness 20394 non-null float64 9 loudness 20394 non-null float64 10 mode 19888 non-null object 11 speechiness 20394 non-null float64 12 tempo 19952 non-null float64 13 obtained_date 20394 non-null object 14 valence 20394 non-null float64 15 music_genre 20394 non-null object dtypes: float64(10), int64(1), object(5) memory usage: 2.5+ MB
# назначим instance_id индексной колонкой
train_df = train_df.set_index('instance_id', drop=True)
train_df.head()
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 25143 | Highwayman | 0.4800 | 0.670 | 182653.0 | 0.351 | 0.017600 | D | 0.115 | -16.842 | Major | 0.0463 | 101.384 | 4-Apr | 0.450 | Country |
| 26091 | Toes Across The Floor | 0.2430 | 0.452 | 187133.0 | 0.670 | 0.000051 | A | 0.108 | -8.392 | Minor | 0.0352 | 113.071 | 4-Apr | 0.539 | Rock |
| 87888 | First Person on Earth | 0.2280 | 0.454 | 173448.0 | 0.804 | 0.000000 | E | 0.181 | -5.225 | Minor | 0.3710 | 80.980 | 4-Apr | 0.344 | Alternative |
| 77021 | No Te Veo - Digital Single | 0.0558 | 0.847 | 255987.0 | 0.873 | 0.000003 | G# | 0.325 | -4.805 | Minor | 0.0804 | 116.007 | 4-Apr | 0.966 | Hip-Hop |
| 20852 | Chasing Shadows | 0.2270 | 0.742 | 195333.0 | 0.575 | 0.000002 | C | 0.176 | -5.550 | Major | 0.0487 | 76.494 | 4-Apr | 0.583 | Alternative |
Устранение пропусков¶
- Пропуски в колонке
key
# Сгрупперуем пропуски по жанрам
train_df[train_df.key.isna()].music_genre.value_counts()
Blues 104 Rap 100 Electronic 91 Alternative 90 Rock 79 Country 76 Anime 67 Hip-Hop 46 Jazz 42 Classical 40 Name: music_genre, dtype: int64
Пропуски имеются во всех жанрах, трудно придумать чем этот признак выражается даже если делать синтетический признак, поэтому удаляем, либо заполняем через SimpleImputer, в тренировочных данных лучше удалим.
train_df = train_df[train_df.key.notnull()]
show_miss(train_df)
Оставшиеся пропуски (%):
mode 2.5 tempo 2.2 dtype: float64
- Пропуски в колонке
mode
# Сгрупперуем пропуски по жанрам
train_df[train_df['mode'].isna()].music_genre.value_counts()
Rap 68 Blues 68 Electronic 63 Alternative 59 Rock 58 Country 51 Anime 47 Classical 33 Jazz 23 Hip-Hop 14 Name: music_genre, dtype: int64
Также как и с категорией key, поступаем также.
train_df = train_df[train_df['mode'].notnull()]
show_miss(train_df)
Оставшиеся пропуски (%):
tempo 2.2 dtype: float64
- Пропуски в колонке
tempo
# Сгрупперуем пропуски по жанрам
train_df[train_df['tempo'].isna()].music_genre.value_counts()
Blues 58 Alternative 54 Rock 53 Rap 50 Country 47 Electronic 47 Anime 36 Classical 33 Hip-Hop 27 Jazz 16 Name: music_genre, dtype: int64
#train_df = train_df[train_df['tempo'].notnull()]
#show_miss(train_df)
Заполним пропуски тут с помощью фильтра IterativeImputer далее уже в конвеере.
Удаление дубликатов¶
# определение
print("Записей до удаления:", train_df.shape[0])
train_df[train_df.duplicated(keep=False)]
Записей до удаления: 19175
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 56339 | Fuck It Up | 0.18 | 0.552 | 188909.0 | 0.686 | 0.00461 | G# | 0.328 | -9.409 | Major | 0.112 | 131.663 | 4-Apr | 0.347 | Electronic |
| 60651 | Fuck It Up | 0.18 | 0.552 | 188909.0 | 0.686 | 0.00461 | G# | 0.328 | -9.409 | Major | 0.112 | 131.663 | 4-Apr | 0.347 | Electronic |
# удаление
train_df = train_df.drop_duplicates()
print("Записей после удаления:", train_df.shape[0])
Записей после удаления: 19174
Также проверим данные на неявные дубликаты
train_df[train_df.duplicated(
['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode'],
keep=False
)].sort_values(by='track_name')
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 71962 | '75 aka Stay With You | 0.030700 | 0.712 | 371507.0 | 0.8270 | 0.795 | C# | 0.122 | -7.331 | Major | 0.0414 | 126.001 | 4-Apr | 0.3050 | Electronic |
| 77103 | '75 aka Stay With You | 0.030700 | 0.712 | 371507.0 | 0.8270 | 0.795 | C# | 0.122 | -7.331 | Major | 0.0414 | 126.001 | 4-Apr | 0.3050 | Jazz |
| 20145 | (Your Love Keeps Lifting Me) Higher & Higher | 0.176000 | 0.631 | 181067.0 | 0.6900 | 0.000 | D | 0.121 | -6.676 | Major | 0.0531 | 94.574 | 4-Apr | 0.9380 | Blues |
| 60181 | (Your Love Keeps Lifting Me) Higher & Higher | 0.176000 | 0.631 | 181067.0 | 0.6900 | 0.000 | D | 0.121 | -6.676 | Major | 0.0531 | 94.574 | 4-Apr | 0.9380 | Rock |
| 43672 | 1 on 1 | 0.000525 | 0.905 | 192840.0 | 0.6770 | 0.000 | D | 0.251 | -7.987 | Major | 0.3450 | 140.882 | 4-Apr | 0.5980 | Rap |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 63663 | wokeuplikethis* | 0.013800 | 0.785 | 235535.0 | 0.6200 | 0.000 | G# | 0.150 | -6.668 | Major | 0.2540 | 78.476 | 4-Apr | 0.4780 | Rap |
| 27588 | アシタカとサン - Piano Solo Feature | 0.975000 | 0.338 | 273400.0 | 0.1260 | 0.911 | C# | 0.107 | -14.419 | Major | 0.0284 | 79.882 | 4-Apr | 0.0648 | Classical |
| 38507 | アシタカとサン - Piano Solo Feature | 0.975000 | 0.338 | 273400.0 | 0.1260 | 0.911 | C# | 0.107 | -14.419 | Major | 0.0284 | 79.882 | 4-Apr | 0.0648 | Anime |
| 36885 | 花水木の咲く頃 - 辻井伸行 | 0.994000 | 0.270 | 209600.0 | 0.0171 | 0.921 | E | 0.145 | -31.429 | Major | 0.0410 | 70.931 | 4-Apr | 0.1470 | Anime |
| 24851 | 花水木の咲く頃 - 辻井伸行 | 0.994000 | 0.270 | 209600.0 | 0.0171 | 0.921 | E | 0.145 | -31.429 | Major | 0.0410 | 70.931 | 4-Apr | 0.1470 | Classical |
985 rows × 15 columns
Опытным путем было обнаружено что существуют идентичные категории но в каком-то случае неправильно классифицированны, а также есть немного записей с идентичными данными но у них имеются различия в названиях, и это проблема, мы не можем эти данные как-то просто отнести к како-либо категории просто, конечно самое простое что можно придумать это удалить проблемные данные в обоих вариантах.
Для начала разберемся с неправильно названными записями
train_df[train_df.duplicated(
['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode','music_genre'],
keep=False
)].sort_values(by='track_name')
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 68462 | Back To The Future | 0.018600 | 0.583 | 180000.0 | 0.841 | 0.0000 | D | 0.105 | -3.665 | Major | 0.1900 | 90.646 | 4-Apr | 0.495 | Electronic |
| 27101 | Back To The Future (feat. ProbCause) | 0.018600 | 0.583 | 180000.0 | 0.841 | 0.0000 | D | 0.105 | -3.665 | Major | 0.1900 | 90.646 | 4-Apr | 0.495 | Electronic |
| 61145 | Extreme Ways (Bourne's Legacy) | 0.000417 | 0.466 | 290827.0 | 0.828 | 0.2050 | B | 0.335 | -7.898 | Minor | 0.0518 | 104.495 | 4-Apr | 0.299 | Electronic |
| 23448 | Extreme Ways (Bourne's Legacy) - Original Version | 0.000417 | 0.466 | 290827.0 | 0.828 | 0.2050 | B | 0.335 | -7.898 | Minor | 0.0518 | 104.495 | 4-Apr | 0.299 | Electronic |
| 21594 | Forever - FuntCase Remix | 0.013400 | 0.312 | 321600.0 | 0.981 | 0.0619 | F | 0.289 | -2.788 | Minor | 0.3800 | 149.600 | 4-Apr | 0.043 | Electronic |
| 22037 | Forever - Funtcase Remix | 0.013400 | 0.312 | 321600.0 | 0.981 | 0.0619 | F | 0.289 | -2.788 | Minor | 0.3800 | 149.600 | 4-Apr | 0.043 | Electronic |
| 66646 | Many Shades Of Black | 0.019600 | 0.408 | 264627.0 | 0.728 | 0.0000 | G | 0.462 | -4.287 | Minor | 0.0340 | 95.438 | 3-Apr | 0.632 | Blues |
| 36123 | Many Shades of Black | 0.019600 | 0.408 | 264627.0 | 0.728 | 0.0000 | G | 0.462 | -4.287 | Minor | 0.0340 | 95.438 | 3-Apr | 0.632 | Blues |
| 84270 | Post To Be (feat. Chris Brown & Jhene Aiko) | 0.069700 | 0.733 | 226581.0 | 0.676 | 0.0000 | A# | 0.208 | -5.655 | Minor | 0.0432 | 97.448 | 4-Apr | 0.701 | Rap |
| 36964 | Post to Be (feat. Chris Brown & Jhene Aiko) | 0.069700 | 0.733 | 226581.0 | 0.676 | 0.0000 | A# | 0.208 | -5.655 | Minor | 0.0432 | 97.448 | 4-Apr | 0.701 | Rap |
Повезло что тут не так много их, поступим просто, будем оставлять первую запись.
train_df.drop_duplicates(subset=['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode','music_genre'], inplace=True)
train_df[train_df.duplicated(
['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode','music_genre'],
keep=False
)].sort_values(by='track_name')
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id |
Далее сложней, нужно определить правильную классификацию, т.е. правильно промаркировать данные исходя из их параметров, и так как параметров не мало на глаз это сделать проблемно, пока примем решение почистить данные от обоих вариантов (далее можно попробовать восстановить часть данных в борьбе за лучшую метрику).
train_df.drop_duplicates(subset=['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode'], keep=False, inplace=True)
train_df[train_df.duplicated(
['acousticness','danceability','duration_ms','energy','instrumentalness',
'liveness','loudness','speechiness','tempo','valence','key','mode'],
keep=False
)].sort_values(by='track_name')
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id |
Теперь с дубликатами покончено.
Устранение выбросов¶
- В
duration_ms
print(f"доля отрицательных значений в duration_ms: {train_df.query('duration_ms == -1.0').shape[0] / train_df.shape[0]:.2%}")
доля отрицательных значений в duration_ms: 10.26%
Заменим на пропуски, а далее с помощью IterativeImputer уже в конвеере заполним их.
train_df.loc[train_df.duration_ms == -1.0, 'duration_ms'] = np.nan
train_df[train_df.duration_ms.isna()]
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 73565 | Iwanakutemo Tsutawaru Arewa Sukoshi Usoda - Al... | 0.11000 | 0.470 | NaN | 0.6410 | 0.000000 | F# | 0.1550 | -7.454 | Major | 0.0658 | 86.387 | 4-Apr | 0.638 | Anime |
| 76827 | Flodgin | 0.08480 | 0.748 | NaN | 0.7120 | 0.000000 | D | 0.1320 | -5.651 | Major | 0.1430 | 140.055 | 4-Apr | 0.180 | Rap |
| 43888 | Things My Father Said | 0.05400 | 0.482 | NaN | 0.6270 | 0.000003 | D | 0.1110 | -5.779 | Major | 0.0300 | 95.903 | 4-Apr | 0.170 | Alternative |
| 39574 | Fidelio, Op. 72, Act I: Ha! Welch ein Augenbli... | 0.90900 | 0.374 | NaN | 0.2270 | 0.000069 | A# | 0.6860 | -18.719 | Major | 0.0579 | 78.236 | 4-Apr | 0.264 | Classical |
| 42479 | American Grown | 0.04810 | 0.615 | NaN | 0.7790 | 0.000000 | C | 0.2300 | -9.577 | Major | 0.0832 | 147.987 | 4-Apr | 0.899 | Country |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 51087 | Bachianas brasileiras No. 7: III. Toccata: Des... | 0.90900 | 0.283 | NaN | 0.1330 | 0.800000 | B | 0.0927 | -25.586 | Minor | 0.0412 | 116.376 | 4-Apr | 0.177 | Classical |
| 48460 | Airborne - MUST DIE! Remix | 0.00298 | 0.450 | NaN | 0.9150 | 0.446000 | F | 0.7710 | -1.404 | Minor | 0.2360 | 137.948 | 4-Apr | 0.156 | Electronic |
| 23677 | Idomeneo, TrV 262, Act I (After W.A. Mozart): ... | 0.98200 | 0.372 | NaN | 0.1590 | 0.000001 | G | 0.1650 | -21.129 | Major | 0.0364 | 135.617 | 3-Apr | 0.512 | Classical |
| 89832 | I Want My Milk (I Want it Now) | 0.96600 | 0.638 | NaN | 0.0689 | 0.000001 | E | 0.1020 | -16.222 | Major | 0.0659 | 120.160 | 3-Apr | 0.695 | Blues |
| 73890 | 理性と力 | 0.62700 | 0.461 | NaN | 0.4360 | 0.892000 | G | 0.1330 | -9.381 | Major | 0.0301 | 130.035 | 4-Apr | 0.264 | Anime |
1867 rows × 15 columns
- Все другие поля
В конвеере к числовым полям применим фильтр-функцию
def outlier_detector_(X, y, consm=0.01, random_state=RANDOM_STATE):
outlier_index = IsolationForest(contamination=consm, random_state=random_state).fit_predict(X)
return X[outlier_index == 1], y[outlier_index == 1]
Разработка новых синтетических признаков¶
- Метрики текстовых полей
def make_col_txt_f3(data, col_target='track_name'):
'''Создание признака по текстовому полю - отношение количества слов к количеству символов'''
data.loc[:, 'txt_f3'] = data.apply(
lambda x: len(re.findall(r'\b\w+\b', x[col_target])) / len(x[col_target]),
axis=1
)
return data
Можно было бы использовать какие-нибудь продвинутые библиотеки, но для простоты решения было решено сделать простой подсчет символов и слов, и их отношение взять как метрику.
make_col_txt_f3(train_df)
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | txt_f3 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | ||||||||||||||||
| 25143 | Highwayman | 0.48000 | 0.670 | 182653.0 | 0.351 | 0.017600 | D | 0.1150 | -16.842 | Major | 0.0463 | 101.384 | 4-Apr | 0.450 | Country | 0.100000 |
| 26091 | Toes Across The Floor | 0.24300 | 0.452 | 187133.0 | 0.670 | 0.000051 | A | 0.1080 | -8.392 | Minor | 0.0352 | 113.071 | 4-Apr | 0.539 | Rock | 0.190476 |
| 77021 | No Te Veo - Digital Single | 0.05580 | 0.847 | 255987.0 | 0.873 | 0.000003 | G# | 0.3250 | -4.805 | Minor | 0.0804 | 116.007 | 4-Apr | 0.966 | Hip-Hop | 0.192308 |
| 20852 | Chasing Shadows | 0.22700 | 0.742 | 195333.0 | 0.575 | 0.000002 | C | 0.1760 | -5.550 | Major | 0.0487 | 76.494 | 4-Apr | 0.583 | Alternative | 0.133333 |
| 43934 | Eskimo Blue Day - Remastered | 0.10200 | 0.308 | 392893.0 | 0.590 | 0.371000 | D | 0.1120 | -11.703 | Major | 0.0345 | 145.758 | 4-Apr | 0.496 | Blues | 0.142857 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 63232 | Wait (The Whisper Song) | 0.00112 | 0.933 | 179160.0 | 0.513 | 0.002480 | G | 0.1070 | -13.203 | Major | 0.3470 | 102.017 | 4-Apr | 0.595 | Rap | 0.173913 |
| 44799 | Mt. Washington | 0.19000 | 0.482 | 198933.0 | 0.362 | 0.005620 | F# | 0.0913 | -10.358 | Minor | 0.0299 | 76.879 | 4-Apr | 0.174 | Rock | 0.142857 |
| 33350 | Original Prankster | 0.00061 | 0.663 | 220947.0 | 0.886 | 0.000025 | D | 0.2840 | -4.149 | Major | 0.0358 | 146.803 | 4-Apr | 0.942 | Alternative | 0.111111 |
| 77920 | 4Peat | 0.00310 | 0.914 | 162214.0 | 0.515 | 0.000000 | C# | 0.1050 | -9.934 | Major | 0.3560 | 150.016 | 4-Apr | 0.215 | Rap | 0.200000 |
| 86375 | Trouble (feat. MC Spyder) | 0.05350 | 0.717 | 271885.0 | 0.983 | 0.491000 | D# | 0.1080 | -1.615 | Minor | 0.0942 | 128.004 | 4-Apr | 0.077 | Electronic | 0.160000 |
18193 rows × 16 columns
Проверка на мультиколлинеарность¶
# Корреляции признаков
corr_table = train_df.drop(['track_name'], axis=1).phik_matrix(verbose=False)
sns.heatmap(corr_table, xticklabels=True, yticklabels=True)
plt.title("Корреляция признаков (phik методом)")
plt.show()
corr_table
| acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | music_genre | txt_f3 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| acousticness | 1.000000 | 0.461339 | 0.139702 | 0.767967 | 0.443488 | 0.097909 | 0.190217 | 0.710900 | 0.024274 | 0.208299 | 0.340002 | 0.262095 | 0.399151 | 0.660448 | 0.049976 |
| danceability | 0.461339 | 1.000000 | 0.180987 | 0.517242 | 0.341350 | 0.097089 | 0.139472 | 0.483874 | 0.101906 | 0.331755 | 0.410321 | 0.287712 | 0.535905 | 0.626858 | 0.042418 |
| duration_ms | 0.139702 | 0.180987 | 1.000000 | 0.150584 | 0.182019 | 0.041234 | 0.096930 | 0.180323 | 0.009962 | 0.030624 | 0.093674 | 0.046330 | 0.146016 | 0.209800 | 0.020408 |
| energy | 0.767967 | 0.517242 | 0.150584 | 1.000000 | 0.439477 | 0.090752 | 0.263851 | 0.839780 | 0.032611 | 0.239502 | 0.391116 | 0.293427 | 0.483151 | 0.672398 | 0.042251 |
| instrumentalness | 0.443488 | 0.341350 | 0.182019 | 0.439477 | 1.000000 | 0.026233 | 0.144268 | 0.500141 | 0.090778 | 0.187259 | 0.217113 | 0.134467 | 0.359236 | 0.567779 | 0.065385 |
| key | 0.097909 | 0.097089 | 0.041234 | 0.090752 | 0.026233 | 1.000000 | 0.058193 | 0.083143 | 0.359860 | 0.111959 | 0.043887 | 0.070907 | 0.043970 | 0.170350 | 0.036009 |
| liveness | 0.190217 | 0.139472 | 0.096930 | 0.263851 | 0.144268 | 0.058193 | 1.000000 | 0.205375 | 0.028544 | 0.143479 | 0.080353 | 0.047722 | 0.132633 | 0.214020 | 0.022575 |
| loudness | 0.710900 | 0.483874 | 0.180323 | 0.839780 | 0.500141 | 0.083143 | 0.205375 | 1.000000 | 0.017396 | 0.208346 | 0.358523 | 0.255918 | 0.442265 | 0.680271 | 0.037050 |
| mode | 0.024274 | 0.101906 | 0.009962 | 0.032611 | 0.090778 | 0.359860 | 0.028544 | 0.017396 | 1.000000 | 0.111484 | 0.013916 | 0.062834 | 0.057962 | 0.301252 | 0.040758 |
| speechiness | 0.208299 | 0.331755 | 0.030624 | 0.239502 | 0.187259 | 0.111959 | 0.143479 | 0.208346 | 0.111484 | 1.000000 | 0.184547 | 0.114062 | 0.115671 | 0.497328 | 0.000000 |
| tempo | 0.340002 | 0.410321 | 0.093674 | 0.391116 | 0.217113 | 0.043887 | 0.080353 | 0.358523 | 0.013916 | 0.184547 | 1.000000 | 0.129419 | 0.269249 | 0.340332 | 0.000000 |
| obtained_date | 0.262095 | 0.287712 | 0.046330 | 0.293427 | 0.134467 | 0.070907 | 0.047722 | 0.255918 | 0.062834 | 0.114062 | 0.129419 | 1.000000 | 0.194623 | 0.246471 | 0.000000 |
| valence | 0.399151 | 0.535905 | 0.146016 | 0.483151 | 0.359236 | 0.043970 | 0.132633 | 0.442265 | 0.057962 | 0.115671 | 0.269249 | 0.194623 | 1.000000 | 0.470556 | 0.034757 |
| music_genre | 0.660448 | 0.626858 | 0.209800 | 0.672398 | 0.567779 | 0.170350 | 0.214020 | 0.680271 | 0.301252 | 0.497328 | 0.340332 | 0.246471 | 0.470556 | 1.000000 | 0.145798 |
| txt_f3 | 0.049976 | 0.042418 | 0.020408 | 0.042251 | 0.065385 | 0.036009 | 0.022575 | 0.037050 | 0.040758 | 0.000000 | 0.000000 | 0.000000 | 0.034757 | 0.145798 | 1.000000 |
Признак loudness к energy коррелирует аж в 0.839780!
Также признак acousticness к energy коррелирует в 0.767967.
Можно попробовать оценить как повлияет на метрики удаление признака energy (но как удалось проверить - метрика f1 лишь упала, возможно надо перебрать заново гиперпараметры под новые условия, но перебирая параметры порога в объекте DropCorrelatedFeatures в нашем конвеере при подборе гиперпараметров (и не только их) было выявлено лучшим порогом 0.9, а значит эти признаки не были удалены)
Отбор финального набора обучающих признаков¶
# Категориальные признаки
cat_colunmns = ['key', 'mode']
# Числовые признаки
num_columns = [
'acousticness', 'danceability', 'instrumentalness','energy',
'liveness', 'loudness', 'speechiness', 'tempo', 'valence', 'duration_ms',
'txt_f3'
]
# Удаленные признаки
del_columns = ['track_name', 'obtained_date']
Удалили колонку даты (obtained_date) так как она не несет полезной информации, и нелогично вообще ее использовать т.к. она не характеризует сами объекты исследования. Из колонки названия произведения (track_name) мы извлекли метрику по тексту, сама же колонка нам очень избыточна и не очень полезна для обучения.
Подготовка данных для работы¶
X = train_df.drop(columns=del_columns+['music_genre'])
y = train_df.music_genre
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.25, random_state=RANDOM_STATE, stratify=y) # равномерно распределим классы в выбоке задав параметр stratify
print(f"Записей в train выборке: {X_train.shape[0]/X.shape[0]:.0%}")
print(f"Записей в valid выборке: {X_valid.shape[0]/X.shape[0]:.0%}")
X_train.head()
Записей в train выборке: 75% Записей в valid выборке: 25%
| acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | valence | txt_f3 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||
| 38582 | 0.81800 | 0.402 | NaN | 0.261 | 0.000053 | F | 0.6430 | -8.874 | Minor | 0.0315 | 75.413 | 0.214 | 0.155556 |
| 59383 | 0.66100 | 0.454 | 309227.0 | 0.328 | 0.035600 | D# | 0.1090 | -12.699 | Major | 0.0339 | 154.375 | 0.503 | 0.250000 |
| 89312 | 0.66900 | 0.626 | NaN | 0.495 | 0.000000 | F | 0.1120 | -8.282 | Major | 0.0449 | 81.012 | 0.434 | 0.214286 |
| 65051 | 0.40600 | 0.754 | 249507.0 | 0.496 | 0.000002 | B | 0.0899 | -7.357 | Minor | 0.0450 | 89.970 | 0.701 | 0.250000 |
| 85606 | 0.00724 | 0.550 | 178878.0 | 0.907 | 0.000000 | G | 0.5180 | -2.996 | Major | 0.1010 | 96.051 | 0.537 | 0.125000 |
y.value_counts().plot(kind='bar')
plt.title("Баланс классов в исходных данных")
plt.show()
y_train.value_counts().plot(kind='bar')
plt.title("Баланс классов в тренировочных данных")
plt.show()
y_valid.value_counts().plot(kind='bar')
plt.title("Баланс классов в валидационных данных")
plt.show()
Видим что стратификация отработала как надо и мы сохранили такое-же соотношение классов, пусть и все равно несбалансированно, зато этот баланс не усугубился.
Создание конвеера¶
Цепочка первичных преобразований данных¶
num_pipeline = make_imblearn_pipeline(IterativeImputer(), StandardScaler(), DropCorrelatedFeatures(threshold=0.9))
#cat_pipeline = make_imblearn_pipeline(SimpleImputer(missing_values=np.nan, strategy='most_frequent'))
#почемуто не заработало для catboost без OrdinalEncoder
cat_pipeline = make_imblearn_pipeline(SimpleImputer(missing_values=np.nan, strategy='most_frequent'), OrdinalEncoder())
column_transformer = make_column_transformer(
(num_pipeline, num_columns),
(cat_pipeline, cat_colunmns),
remainder='passthrough'
)
column_transformer
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])['acousticness', 'danceability', 'instrumentalness', 'energy', 'liveness', 'loudness', 'speechiness', 'tempo', 'valence', 'duration_ms', 'txt_f3']
IterativeImputer()
StandardScaler()
DropCorrelatedFeatures(threshold=0.9)
['key', 'mode']
SimpleImputer(strategy='most_frequent')
OrdinalEncoder()
passthrough
Основной конвеер¶
pipeline = make_imblearn_pipeline(
column_transformer, # трансформатор данных, чтобы они были масштабированы и лишены пропусков
FunctionSampler(func=outlier_detector_), # для устранения выбросов (если такие есть)
RandomForestClassifier(random_state=RANDOM_STATE)
)
pipeline
Pipeline(steps=[('columntransformer',
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness',
'danceability',
'instrumentalness', 'energy',
'liveness', 'loudness',
'sp...ss', 'tempo',
'valence', 'duration_ms',
'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])),
('functionsampler',
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)),
('randomforestclassifier',
RandomForestClassifier(random_state=69))])In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('columntransformer',
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness',
'danceability',
'instrumentalness', 'energy',
'liveness', 'loudness',
'sp...ss', 'tempo',
'valence', 'duration_ms',
'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])),
('functionsampler',
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)),
('randomforestclassifier',
RandomForestClassifier(random_state=69))])ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])['acousticness', 'danceability', 'instrumentalness', 'energy', 'liveness', 'loudness', 'speechiness', 'tempo', 'valence', 'duration_ms', 'txt_f3']
IterativeImputer()
StandardScaler()
DropCorrelatedFeatures(threshold=0.9)
['key', 'mode']
SimpleImputer(strategy='most_frequent')
OrdinalEncoder()
passthrough
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)
RandomForestClassifier(random_state=69)
Нахождение лучшей модели¶
Подготовка к перебору¶
# Узнаем названия этапов, которые были автоматически заданны
pipeline.named_steps
{'columntransformer': ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])]),
'functionsampler': FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>),
'randomforestclassifier': RandomForestClassifier(random_state=69)}
Поменяем имя стадии где мы выбираем классификатор, логичнее его назвать как-то обобщающе, например сокращенно clf
# Поменяем название конечного этапа на общее название - clf (классификатор)
pipeline.steps[-1] = ('clf', pipeline.steps[-1][1])
# Проверим что название поменялось
pipeline.named_steps
{'columntransformer': ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])]),
'functionsampler': FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>),
'clf': RandomForestClassifier(random_state=69)}
# Создание объекта класса перебора гиперпараметров, назначение параметров перебора
max_depth_range = range(1, 16)
estimators_range = range(10, 100)
params_search = [
{
'clf': [RandomForestClassifier(random_state=RANDOM_STATE)],
'clf__n_estimators': estimators_range,
'clf__max_depth': max_depth_range,
'clf__max_features': [None, 'sqrt', 5],
'clf__class_weight': ["balanced"],
'columntransformer__pipeline-1__dropcorrelatedfeatures__threshold': [0.7, 0.8, 0.9]
},
{
'clf': [CatBoostClassifier(random_state=RANDOM_STATE, verbose=False, auto_class_weights='Balanced')],
'clf__n_estimators': estimators_range,
'clf__max_depth': max_depth_range,
# 'clf__task_type': ['GPU'],
'columntransformer__pipeline-1__dropcorrelatedfeatures__threshold': [0.7, 0.8, 0.9]
}
]
def print_grid_result(rf_grid):
'''Вывода результата перебора.'''
print("Лучший классификатор:", rf_grid.best_estimator_)
print("Лучшая метрика:", rf_grid.best_score_)
print("Лучшие параметры:\n", rf_grid.best_params_)
Перебор & обучение¶
%%time
# Перебор (или подстановка уже готовых параметров в случае не поиска)
if IS_BEST_MODEL_NOT_FOUND:
grid = RandomizedSearchCV(
pipeline,
params_search,
n_iter=30,
cv=3,
verbose=1,
random_state=RANDOM_STATE,
scoring='f1_micro',
n_jobs=1
).fit(X_train, y_train)
# Сохраним лучший классификатор в отдельную переменную
best_estimator = grid.best_estimator_
print_grid_result(grid)
else:
pipeline.steps[-1] = ('clf', CatBoostClassifier(max_depth=4, n_estimators=71, auto_class_weights='Balanced', random_state=RANDOM_STATE))
best_estimator = pipeline.fit(X_train, y_train)
best_estimator
Learning rate set to 0.5 0: learn: 1.9345361 total: 63.9ms remaining: 4.47s 1: learn: 1.8346161 total: 71.9ms remaining: 2.48s 2: learn: 1.7248385 total: 79.3ms remaining: 1.8s 3: learn: 1.6593859 total: 87.3ms remaining: 1.46s 4: learn: 1.6314436 total: 94.5ms remaining: 1.25s 5: learn: 1.5966725 total: 103ms remaining: 1.11s 6: learn: 1.5643687 total: 110ms remaining: 1s 7: learn: 1.5414249 total: 117ms remaining: 922ms 8: learn: 1.5278853 total: 126ms remaining: 866ms 9: learn: 1.5116770 total: 133ms remaining: 809ms 10: learn: 1.4931345 total: 140ms remaining: 765ms 11: learn: 1.4839330 total: 148ms remaining: 726ms 12: learn: 1.4785739 total: 155ms remaining: 692ms 13: learn: 1.4706867 total: 163ms remaining: 662ms 14: learn: 1.4641184 total: 172ms remaining: 641ms 15: learn: 1.4543192 total: 182ms remaining: 624ms 16: learn: 1.4420643 total: 189ms remaining: 602ms 17: learn: 1.4321181 total: 197ms remaining: 580ms 18: learn: 1.4274889 total: 206ms remaining: 563ms 19: learn: 1.4230624 total: 214ms remaining: 545ms 20: learn: 1.4200508 total: 222ms remaining: 528ms 21: learn: 1.4141543 total: 230ms remaining: 512ms 22: learn: 1.4027947 total: 241ms remaining: 503ms 23: learn: 1.3971262 total: 250ms remaining: 489ms 24: learn: 1.3915324 total: 258ms remaining: 474ms 25: learn: 1.3882737 total: 266ms remaining: 461ms 26: learn: 1.3804916 total: 275ms remaining: 447ms 27: learn: 1.3752791 total: 283ms remaining: 434ms 28: learn: 1.3702914 total: 290ms remaining: 420ms 29: learn: 1.3647757 total: 298ms remaining: 407ms 30: learn: 1.3614998 total: 307ms remaining: 396ms 31: learn: 1.3569641 total: 315ms remaining: 384ms 32: learn: 1.3521652 total: 323ms remaining: 372ms 33: learn: 1.3476726 total: 331ms remaining: 360ms 34: learn: 1.3447543 total: 339ms remaining: 349ms 35: learn: 1.3404884 total: 347ms remaining: 337ms 36: learn: 1.3352323 total: 355ms remaining: 326ms 37: learn: 1.3312477 total: 362ms remaining: 314ms 38: learn: 1.3284753 total: 370ms remaining: 304ms 39: learn: 1.3254137 total: 380ms remaining: 295ms 40: learn: 1.3218046 total: 390ms remaining: 285ms 41: learn: 1.3182278 total: 402ms remaining: 278ms 42: learn: 1.3152363 total: 411ms remaining: 268ms 43: learn: 1.3124813 total: 420ms remaining: 258ms 44: learn: 1.3101565 total: 429ms remaining: 248ms 45: learn: 1.3081553 total: 438ms remaining: 238ms 46: learn: 1.3034417 total: 445ms remaining: 227ms 47: learn: 1.3011884 total: 452ms remaining: 217ms 48: learn: 1.2984344 total: 459ms remaining: 206ms 49: learn: 1.2953997 total: 466ms remaining: 196ms 50: learn: 1.2935085 total: 472ms remaining: 185ms 51: learn: 1.2907936 total: 480ms remaining: 175ms 52: learn: 1.2886016 total: 487ms remaining: 165ms 53: learn: 1.2862064 total: 495ms remaining: 156ms 54: learn: 1.2845462 total: 502ms remaining: 146ms 55: learn: 1.2823068 total: 509ms remaining: 136ms 56: learn: 1.2801764 total: 517ms remaining: 127ms 57: learn: 1.2762660 total: 526ms remaining: 118ms 58: learn: 1.2736514 total: 533ms remaining: 108ms 59: learn: 1.2687247 total: 540ms remaining: 99.1ms 60: learn: 1.2672878 total: 548ms remaining: 89.8ms 61: learn: 1.2659987 total: 555ms remaining: 80.5ms 62: learn: 1.2641192 total: 562ms remaining: 71.3ms 63: learn: 1.2617295 total: 572ms remaining: 62.6ms 64: learn: 1.2603823 total: 579ms remaining: 53.5ms 65: learn: 1.2592011 total: 588ms remaining: 44.5ms 66: learn: 1.2568708 total: 596ms remaining: 35.6ms 67: learn: 1.2534956 total: 604ms remaining: 26.6ms 68: learn: 1.2505817 total: 612ms remaining: 17.7ms 69: learn: 1.2487430 total: 619ms remaining: 8.85ms 70: learn: 1.2458355 total: 628ms remaining: 0us CPU times: total: 359 ms Wall time: 1.4 s
Pipeline(steps=[('columntransformer',
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness',
'danceability',
'instrumentalness', 'energy',
'liveness', 'loudness',
'sp...', 'tempo',
'valence', 'duration_ms',
'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])),
('functionsampler',
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)),
('clf',
<catboost.core.CatBoostClassifier object at 0x000002971631FD30>)])In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
Pipeline(steps=[('columntransformer',
ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness',
'danceability',
'instrumentalness', 'energy',
'liveness', 'loudness',
'sp...', 'tempo',
'valence', 'duration_ms',
'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])),
('functionsampler',
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)),
('clf',
<catboost.core.CatBoostClassifier object at 0x000002971631FD30>)])ColumnTransformer(remainder='passthrough',
transformers=[('pipeline-1',
Pipeline(steps=[('iterativeimputer',
IterativeImputer()),
('standardscaler',
StandardScaler()),
('dropcorrelatedfeatures',
DropCorrelatedFeatures(threshold=0.9))]),
['acousticness', 'danceability',
'instrumentalness', 'energy', 'liveness',
'loudness', 'speechiness', 'tempo', 'valence',
'duration_ms', 'txt_f3']),
('pipeline-2',
Pipeline(steps=[('simpleimputer',
SimpleImputer(strategy='most_frequent')),
('ordinalencoder',
OrdinalEncoder())]),
['key', 'mode'])])['acousticness', 'danceability', 'instrumentalness', 'energy', 'liveness', 'loudness', 'speechiness', 'tempo', 'valence', 'duration_ms', 'txt_f3']
IterativeImputer()
StandardScaler()
DropCorrelatedFeatures(threshold=0.9)
['key', 'mode']
SimpleImputer(strategy='most_frequent')
OrdinalEncoder()
[]
passthrough
FunctionSampler(func=<function outlier_detector_ at 0x0000029717C4BC10>)
<catboost.core.CatBoostClassifier object at 0x000002971631FD30>
Лучшие параметры составного объекта (конвеера) были найдены такие:
DropCorrelatedFeatures(threshold=0.9)
CatBoostClassifier(max_depth=4, n_estimators=71, auto_class_weights='Balanced')
CatBoostClassifier по данным показал лучшие результаты по метрике f1_micro.
Оценка модели на валидационной выборке¶
predicted = best_estimator.predict(X_valid)
print("f1_micro:", f1_score(y_valid, predicted, average='micro').round(3))
f1_micro: 0.454
sk_report = classification_report(
digits=4,
y_true=y_valid,
y_pred=predicted
)
print(sk_report)
precision recall f1-score support
Alternative 0.4086 0.2709 0.3258 561
Anime 0.4843 0.4854 0.4848 445
Blues 0.5117 0.4343 0.4698 654
Classical 0.7442 0.7926 0.7676 323
Country 0.4164 0.5685 0.4807 482
Electronic 0.6441 0.6094 0.6263 594
Hip-Hop 0.2628 0.4402 0.3291 234
Jazz 0.2973 0.4615 0.3616 286
Rap 0.4673 0.3883 0.4242 515
Rock 0.2500 0.1912 0.2167 455
accuracy 0.4542 4549
macro avg 0.4487 0.4642 0.4487 4549
weighted avg 0.4625 0.4542 0.4517 4549
Видно, что хуже всего дела с жанром Rock, модели сложней всего его определить, а лучше всего модель справляется с Classical и Electronic, видимо маловато признаков, возможно стоит уделить время на синтез вторичных признаков, чтобы попытаться улучшить показатели в ходших классах (нужно понять чем рок отличается от других классов по параметрам, может распросить профессиональных музыкантов, кто знаком с этой сферой хорошо)
Получение предсказайний на тестовых данных¶
Загрузка тестовых данных¶
test_df = load_dataset(PATH_DATASET_TEST)
Файл "kaggle_music_genre_test.csv" успешно загружен.
| instance_id | acousticness | danceability | duration_ms | energy | instrumentalness | liveness | loudness | speechiness | tempo | valence | |
|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 5099.000000 | 5099.000000 | 5099.000000 | 5.099000e+03 | 5099.000000 | 5099.000000 | 5099.000000 | 5099.000000 | 5099.000000 | 4978.000000 | 5099.000000 |
| mean | 55643.871347 | 0.276324 | 0.561888 | 2.173974e+05 | 0.622030 | 0.165198 | 0.197924 | -8.630186 | 0.092718 | 121.246463 | 0.465593 |
| std | 20762.384803 | 0.322657 | 0.170502 | 1.156374e+05 | 0.251829 | 0.311940 | 0.166241 | 5.619766 | 0.100130 | 29.875950 | 0.244217 |
| min | 20012.000000 | 0.000002 | 0.059600 | -1.000000e+00 | 0.001540 | 0.000000 | 0.020400 | -46.122000 | 0.022600 | 37.114000 | 0.020500 |
| 25% | 37571.000000 | 0.016600 | 0.451000 | 1.737335e+05 | 0.465000 | 0.000000 | 0.096950 | -10.231000 | 0.035700 | 96.070250 | 0.272000 |
| 50% | 55246.000000 | 0.120000 | 0.568000 | 2.175000e+05 | 0.660000 | 0.000157 | 0.129000 | -7.135000 | 0.048700 | 120.053500 | 0.458000 |
| 75% | 73702.000000 | 0.460000 | 0.681000 | 2.642470e+05 | 0.826000 | 0.092750 | 0.248000 | -5.127000 | 0.096550 | 141.934250 | 0.650000 |
| max | 91709.000000 | 0.996000 | 0.977000 | 1.360027e+06 | 0.999000 | 0.994000 | 0.990000 | 1.949000 | 0.918000 | 216.029000 | 0.982000 |
| instance_id | track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 48564 | Low Class Conspiracy | 0.301000 | 0.757 | 146213.0 | 0.679 | 0.00000 | A# | 0.3030 | -7.136 | Minor | 0.3560 | 90.361 | 4-Apr | 0.895 |
| 1 | 72394 | The Hunter | 0.538000 | 0.256 | 240360.0 | 0.523 | 0.00832 | G# | 0.0849 | -5.175 | Major | 0.0294 | 78.385 | 4-Apr | 0.318 |
| 2 | 88081 | Hate Me Now | 0.005830 | 0.678 | 284000.0 | 0.770 | 0.00000 | A | 0.1090 | -4.399 | Minor | 0.2220 | 90.000 | 4-Apr | 0.412 |
| 3 | 78331 | Somebody Ain't You | 0.020300 | 0.592 | 177354.0 | 0.749 | 0.00000 | B | 0.1220 | -4.604 | Major | 0.0483 | 160.046 | 4-Apr | 0.614 |
| 4 | 72636 | Sour Mango | 0.000335 | 0.421 | -1.0 | 0.447 | 0.01480 | D | 0.0374 | -8.833 | Major | 0.2020 | 73.830 | 4-Apr | 0.121 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 5094 | 50532 | What We Gonna Do About It | 0.108000 | 0.558 | 163049.0 | 0.767 | 0.00000 | E | 0.0954 | -4.561 | Minor | 0.0491 | 158.019 | 4-Apr | 0.715 |
| 5095 | 26255 | Marilyn (feat. Dominique Le Jeune) | 0.131000 | 0.435 | 196216.0 | 0.641 | 0.00000 | A# | 0.2730 | -7.274 | Major | 0.1040 | 115.534 | 3-Apr | 0.156 |
| 5096 | 67924 | Bipolar | 0.152000 | 0.756 | 243373.0 | 0.787 | 0.00000 | D | 0.2050 | -7.423 | Major | 0.2400 | 123.405 | 4-Apr | 0.459 |
| 5097 | 79778 | Dead - NGHTMRE Remix | 0.001450 | 0.489 | 185600.0 | 0.974 | 0.63800 | F# | 0.1230 | -2.857 | Minor | 0.0381 | 150.036 | 4-Apr | 0.665 |
| 5098 | 47986 | A Night In Tunisia - Remastered 1998 / Rudy Va... | 0.715000 | 0.538 | 256800.0 | 0.520 | 0.01750 | G | 0.0676 | -10.127 | Minor | 0.0408 | 83.816 | 4-Apr | 0.764 |
5099 rows × 15 columns
<class 'pandas.core.frame.DataFrame'> RangeIndex: 5099 entries, 0 to 5098 Data columns (total 15 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 instance_id 5099 non-null int64 1 track_name 5099 non-null object 2 acousticness 5099 non-null float64 3 danceability 5099 non-null float64 4 duration_ms 5099 non-null float64 5 energy 5099 non-null float64 6 instrumentalness 5099 non-null float64 7 key 4941 non-null object 8 liveness 5099 non-null float64 9 loudness 5099 non-null float64 10 mode 4950 non-null object 11 speechiness 5099 non-null float64 12 tempo 4978 non-null float64 13 obtained_date 5099 non-null object 14 valence 5099 non-null float64 dtypes: float64(10), int64(1), object(4) memory usage: 597.7+ KB
show_miss(test_df)
Оставшиеся пропуски (%):
key 3.1 mode 2.9 tempo 2.4 dtype: float64
Данные также без проблем загружены, и имеют пропуски, к слову в тех же колонках!
Подготовка данных (в приемлемый вид)¶
# назначим instance_id индексной колонкой
test_df = test_df.set_index('instance_id', drop=True)
test_df.head()
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | ||||||||||||||
| 48564 | Low Class Conspiracy | 0.301000 | 0.757 | 146213.0 | 0.679 | 0.00000 | A# | 0.3030 | -7.136 | Minor | 0.3560 | 90.361 | 4-Apr | 0.895 |
| 72394 | The Hunter | 0.538000 | 0.256 | 240360.0 | 0.523 | 0.00832 | G# | 0.0849 | -5.175 | Major | 0.0294 | 78.385 | 4-Apr | 0.318 |
| 88081 | Hate Me Now | 0.005830 | 0.678 | 284000.0 | 0.770 | 0.00000 | A | 0.1090 | -4.399 | Minor | 0.2220 | 90.000 | 4-Apr | 0.412 |
| 78331 | Somebody Ain't You | 0.020300 | 0.592 | 177354.0 | 0.749 | 0.00000 | B | 0.1220 | -4.604 | Major | 0.0483 | 160.046 | 4-Apr | 0.614 |
| 72636 | Sour Mango | 0.000335 | 0.421 | -1.0 | 0.447 | 0.01480 | D | 0.0374 | -8.833 | Major | 0.2020 | 73.830 | 4-Apr | 0.121 |
Видно что в duration_ms также имеются отрицательные значения, но они все заполнятся т.к. в нашем конвеере есть для этого специальный объект для такого случая -IterativeImputer, он реализует многомерные алгоритмы восстановления пропущенных значений, оценивая другие значения в наборе.
# добавление синтетического признака
make_col_txt_f3(test_df)
| track_name | acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | obtained_date | valence | txt_f3 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||||
| 48564 | Low Class Conspiracy | 0.301000 | 0.757 | 146213.0 | 0.679 | 0.00000 | A# | 0.3030 | -7.136 | Minor | 0.3560 | 90.361 | 4-Apr | 0.895 | 0.150000 |
| 72394 | The Hunter | 0.538000 | 0.256 | 240360.0 | 0.523 | 0.00832 | G# | 0.0849 | -5.175 | Major | 0.0294 | 78.385 | 4-Apr | 0.318 | 0.200000 |
| 88081 | Hate Me Now | 0.005830 | 0.678 | 284000.0 | 0.770 | 0.00000 | A | 0.1090 | -4.399 | Minor | 0.2220 | 90.000 | 4-Apr | 0.412 | 0.272727 |
| 78331 | Somebody Ain't You | 0.020300 | 0.592 | 177354.0 | 0.749 | 0.00000 | B | 0.1220 | -4.604 | Major | 0.0483 | 160.046 | 4-Apr | 0.614 | 0.222222 |
| 72636 | Sour Mango | 0.000335 | 0.421 | -1.0 | 0.447 | 0.01480 | D | 0.0374 | -8.833 | Major | 0.2020 | 73.830 | 4-Apr | 0.121 | 0.200000 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 50532 | What We Gonna Do About It | 0.108000 | 0.558 | 163049.0 | 0.767 | 0.00000 | E | 0.0954 | -4.561 | Minor | 0.0491 | 158.019 | 4-Apr | 0.715 | 0.240000 |
| 26255 | Marilyn (feat. Dominique Le Jeune) | 0.131000 | 0.435 | 196216.0 | 0.641 | 0.00000 | A# | 0.2730 | -7.274 | Major | 0.1040 | 115.534 | 3-Apr | 0.156 | 0.147059 |
| 67924 | Bipolar | 0.152000 | 0.756 | 243373.0 | 0.787 | 0.00000 | D | 0.2050 | -7.423 | Major | 0.2400 | 123.405 | 4-Apr | 0.459 | 0.142857 |
| 79778 | Dead - NGHTMRE Remix | 0.001450 | 0.489 | 185600.0 | 0.974 | 0.63800 | F# | 0.1230 | -2.857 | Minor | 0.0381 | 150.036 | 4-Apr | 0.665 | 0.150000 |
| 47986 | A Night In Tunisia - Remastered 1998 / Rudy Va... | 0.715000 | 0.538 | 256800.0 | 0.520 | 0.01750 | G | 0.0676 | -10.127 | Minor | 0.0408 | 83.816 | 4-Apr | 0.764 | 0.161290 |
5099 rows × 15 columns
# удаление ненужных признаков
X_test = test_df.drop(columns=del_columns)
X_test.head()
| acousticness | danceability | duration_ms | energy | instrumentalness | key | liveness | loudness | mode | speechiness | tempo | valence | txt_f3 | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| instance_id | |||||||||||||
| 48564 | 0.301000 | 0.757 | 146213.0 | 0.679 | 0.00000 | A# | 0.3030 | -7.136 | Minor | 0.3560 | 90.361 | 0.895 | 0.150000 |
| 72394 | 0.538000 | 0.256 | 240360.0 | 0.523 | 0.00832 | G# | 0.0849 | -5.175 | Major | 0.0294 | 78.385 | 0.318 | 0.200000 |
| 88081 | 0.005830 | 0.678 | 284000.0 | 0.770 | 0.00000 | A | 0.1090 | -4.399 | Minor | 0.2220 | 90.000 | 0.412 | 0.272727 |
| 78331 | 0.020300 | 0.592 | 177354.0 | 0.749 | 0.00000 | B | 0.1220 | -4.604 | Major | 0.0483 | 160.046 | 0.614 | 0.222222 |
| 72636 | 0.000335 | 0.421 | -1.0 | 0.447 | 0.01480 | D | 0.0374 | -8.833 | Major | 0.2020 | 73.830 | 0.121 | 0.200000 |
Получение предсказаний¶
predicted = best_estimator.predict(X_test)
print("Выходной размер массива предсказаний:", np.shape(predicted))
predicted
Выходной размер массива предсказаний: (5099, 1)
array([['Hip-Hop'],
['Anime'],
['Hip-Hop'],
...,
['Hip-Hop'],
['Anime'],
['Jazz']], dtype=object)
Экспорт результа в файл¶
# Обвернем результат в датафрейм
result = pd.DataFrame(predicted, columns=['music_genre'], index=X_test.index)
result
| music_genre | |
|---|---|
| instance_id | |
| 48564 | Hip-Hop |
| 72394 | Anime |
| 88081 | Hip-Hop |
| 78331 | Country |
| 72636 | Anime |
| ... | ... |
| 50532 | Country |
| 26255 | Rap |
| 67924 | Hip-Hop |
| 79778 | Anime |
| 47986 | Jazz |
5099 rows × 1 columns
# Экспорт
result.to_csv("prediction_results.csv")
# Проверка что файл создался и мы его можем прочитать
test_result_read = pd.read_csv('prediction_results.csv')
test_result_read
| instance_id | music_genre | |
|---|---|---|
| 0 | 48564 | Hip-Hop |
| 1 | 72394 | Anime |
| 2 | 88081 | Hip-Hop |
| 3 | 78331 | Country |
| 4 | 72636 | Anime |
| ... | ... | ... |
| 5094 | 50532 | Country |
| 5095 | 26255 | Rap |
| 5096 | 67924 | Hip-Hop |
| 5097 | 79778 | Anime |
| 5098 | 47986 | Jazz |
5099 rows × 2 columns
Анализ важности признаков в предсказаниях модели¶
# Установим в TreeExplainer (объяснитель дерева) модель конвейера
model = best_estimator['clf']
explainer = shap.TreeExplainer(model)
# Применим предварительную обработку к X_test
observations = pipeline['columntransformer'].transform(X_test)
observations
array([[ 0.06560431, 1.14557447, -0.52979559, ..., -0.72703201,
1. , 1. ],
[ 0.79555719, -1.76608186, -0.50295821, ..., 0.06054829,
11. , 0. ],
[-0.84351042, 0.68645102, -0.52979559, ..., 1.20611963,
0. , 1. ],
...,
[-0.39331121, 1.13976278, -0.52979559, ..., -0.83954348,
5. , 0. ],
[-0.85700069, -0.41195826, 1.52816718, ..., -0.72703201,
9. , 1. ],
[ 1.34071187, -0.12718548, -0.47334677, ..., -0.5491913 ,
10. , 1. ]])
# Получим значения Shap из предварительно обработанных данных
shap_values = explainer.shap_values(observations)
np.shape(shap_values)
(10, 5099, 13)
# Построим график важности признаков для каждого класса
shap.summary_plot(shap_values, X_test, plot_type="bar", class_names=model.classes_)
Country и Electronic отличаются как-то от всех особой продолжительностью своих произведений, танцевальность влияет на классику что неожиданно. Больше всего из признаков влияет продолжительность и танцевальность, ну и на 3-м месте признак живого исполнения произведения (liveness).
Общий вывод¶
Основное:
- В ходе работ была разработана модель, позволяющую классифицировать музыкальные произведения по жанрам, в качестве лучшего классификатора было решено взять CatBoostClassifier, он дал лучшую метрику при подборе гиперпараметров.
- Были опробованы новые ранее незнакомые инструменты для работы с данными (например такие как
shap,IterativeImputer,DropCorrelatedFeatures, ... ). - Была освоена методология работы с конвеером (pipeline).
- В ходе работ был создан файл с ответами модели по тестовым данным, который был использован в соревновании на платформе Kaggle 12 место с f1=0.47137.
Замечания:
- В поле
duration_msв более 10% данных были выбросы константным значением в -1.0, боло решено считать это за пропуск, но на всякий случай лучше уточнить этот момент. - В полях
key,modeиtempoчто в тренировочных, что в тестовых данных - наблюдаются пропуски, необходимо понять из-за чего они возникают (если бы это была полноценная работа).